{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d8486309-8730-4bb0-ac80-a5ef5ffb6ec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import jieba\n",
    "import glob\n",
    "import torch\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7e33aeab-5c3b-4cc2-b537-e0c8ab1927f5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>花王堂区</td>\n",
       "      <td>113.548961</td>\n",
       "      <td>22.199207</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>望德堂区</td>\n",
       "      <td>113.550183</td>\n",
       "      <td>22.193721</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>大堂区</td>\n",
       "      <td>113.553647</td>\n",
       "      <td>22.188539</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>风顺堂区</td>\n",
       "      <td>113.541928</td>\n",
       "      <td>22.187368</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>花地玛堂区</td>\n",
       "      <td>113.552896</td>\n",
       "      <td>22.207870</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>424</th>\n",
       "      <td>四川省</td>\n",
       "      <td>104.075809</td>\n",
       "      <td>30.651239</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>425</th>\n",
       "      <td>西藏自治区</td>\n",
       "      <td>91.117525</td>\n",
       "      <td>29.647535</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>426</th>\n",
       "      <td>新疆维吾尔自治区</td>\n",
       "      <td>87.627704</td>\n",
       "      <td>43.793026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>427</th>\n",
       "      <td>云南省</td>\n",
       "      <td>102.710002</td>\n",
       "      <td>25.045806</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>428</th>\n",
       "      <td>浙江省</td>\n",
       "      <td>120.152585</td>\n",
       "      <td>30.266597</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>429 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            0           1          2\n",
       "0        花王堂区  113.548961  22.199207\n",
       "1        望德堂区  113.550183  22.193721\n",
       "2         大堂区  113.553647  22.188539\n",
       "3        风顺堂区  113.541928  22.187368\n",
       "4       花地玛堂区  113.552896  22.207870\n",
       "..        ...         ...        ...\n",
       "424       四川省  104.075809  30.651239\n",
       "425     西藏自治区   91.117525  29.647535\n",
       "426  新疆维吾尔自治区   87.627704  43.793026\n",
       "427       云南省  102.710002  25.045806\n",
       "428       浙江省  120.152585  30.266597\n",
       "\n",
       "[429 rows x 3 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# https://github.com/boyan01/ChinaRegionDistrict/blob/master/region.json\n",
    "# https://lbsyun.baidu.com/jsdemo/demo/yLngLatLocation.htm\n",
    "\n",
    "region_data = json.load(open('region.json'))\n",
    "districts_data = sum([x['districts'] for x in region_data['districts']], [])\n",
    "city_data = [[x['name'], x['center']['longitude'], x['center']['latitude']] for x in districts_data]\n",
    "city_data += [[x['name'], x['center']['longitude'], x['center']['latitude']] for x in region_data['districts']]\n",
    "city_data = pd.DataFrame(city_data)\n",
    "\n",
    "city_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f921954b-f179-42c5-b47e-ae7263aa704b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0          花王堂区\n",
       "1          望德堂区\n",
       "2           大堂区\n",
       "3          风顺堂区\n",
       "4         花地玛堂区\n",
       "         ...   \n",
       "424         四川省\n",
       "425       西藏自治区\n",
       "426    新疆维吾尔自治区\n",
       "427         云南省\n",
       "428         浙江省\n",
       "Name: 0, Length: 429, dtype: object"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "city_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "05e7545c-3031-45af-82ce-929390254f0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>406</th>\n",
       "      <td>上海市</td>\n",
       "      <td>121.473662</td>\n",
       "      <td>31.230372</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       0           1          2\n",
       "406  上海市  121.473662  31.230372"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "city_data[city_data[0] == '上海市']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5460c587-c4d1-4fc3-a850-dc58177b341a",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_img_locations = pd.read_csv('训练集/图片中心经纬度.txt', sep=',', header=None)\n",
    "train_img_locations[0] = train_img_locations[0].apply(lambda x: float(str(x).split(' ')[0]))\n",
    "train_img_locations[1] = train_img_locations[1].apply(lambda x: float(str(x).split(' ')[0]))\n",
    "\n",
    "test_img_locations = pd.read_csv('初赛测试集/图片中心经纬度.txt', sep=',', header=None)\n",
    "test_img_locations[0] = test_img_locations[0].apply(lambda x: float(str(x).split(' ')[0]))\n",
    "test_img_locations[1] = test_img_locations[1].apply(lambda x: float(str(x).split(' ')[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d48fc335-6f93-4906-b3e7-6ee16876a0de",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_imgs = glob.glob('./训练集/图片/*')\n",
    "train_imgs.sort()\n",
    "train_imgs = np.array(train_imgs)\n",
    "\n",
    "test_imgs = glob.glob('./初赛测试集/图片/*')\n",
    "test_imgs.sort()\n",
    "test_imgs = np.array(test_imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a1f54ef6-f015-4029-abc6-2186c016ba76",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import NearestNeighbors\n",
    "nbrs = NearestNeighbors(n_neighbors=40, algorithm='ball_tree').fit(test_img_locations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "420fbe30-4393-426e-b0de-f33e9a230c7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "from PIL import Image\n",
    "import requests\n",
    "from transformers import ChineseCLIPProcessor, ChineseCLIPModel\n",
    "\n",
    "model = ChineseCLIPModel.from_pretrained(\"OFA-Sys/chinese-clip-vit-base-patch16\")\n",
    "processor = ChineseCLIPProcessor.from_pretrained(\"OFA-Sys/chinese-clip-vit-base-patch16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eff783c9-fc25-49aa-9542-090a219a329b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/49 [00:00<?, ?it/s]Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.617 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "100%|██████████| 49/49 [02:29<00:00,  3.05s/it]\n"
     ]
    }
   ],
   "source": [
    "questions = open('./数据集更新/初赛测试集/问题.txt').readlines()\n",
    "results = []\n",
    "for question in tqdm(questions):\n",
    "    words = jieba.lcut(question)\n",
    "    words = [x for x in words if len(x) > 1 and not x.isdigit()]\n",
    "    city = words[0]\n",
    "    city_pic_dis, city_pic_index = nbrs.kneighbors([city_data[city_data[0].apply(lambda x: x == city)].values[0][1:]])\n",
    "    city_pic_dis = city_pic_dis[0]\n",
    "    city_pic_index = city_pic_index[0]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # compute image feature\n",
    "        inputs = processor(images=[Image.open(x) for x in test_imgs[city_pic_index]], return_tensors=\"pt\")\n",
    "        image_features = model.get_image_features(**inputs)\n",
    "        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)  # normalize    \n",
    "    \n",
    "        # compute text features\n",
    "        inputs = processor(text=[question], padding=True, return_tensors=\"pt\")\n",
    "        text_features = model.get_text_features(**inputs)\n",
    "        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)  # normalize\n",
    "    \n",
    "        ids = torch.matmul(text_features, image_features.T)[0]\n",
    "        ids = ids.data.cpu().numpy()\n",
    "\n",
    "        result = city_pic_index[ids.argsort()[::-1]][:5]\n",
    "        result = [str(x) for x in result]\n",
    "        result = ','.join(result)\n",
    "        results.append(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "973456f0-8287-4977-acba-dbabe18202ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('res.csv', 'w') as up:\n",
    "    for result in results:\n",
    "        up.write(result + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e610ad48-636a-4616-a0ac-16aa1a81b132",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.41078353, 0.42506945, 0.39543888, 0.41330916, 0.382412  ,\n",
       "       0.39154297, 0.3923232 , 0.39476427, 0.40610224, 0.41434634,\n",
       "       0.40460896, 0.42052194, 0.40704408, 0.3934614 , 0.40922046,\n",
       "       0.42150655, 0.41470432, 0.436787  , 0.39174995, 0.38255706,\n",
       "       0.37405667, 0.38224554, 0.4008456 , 0.39214906, 0.41689512,\n",
       "       0.4153441 , 0.38414806, 0.45356262, 0.4016958 , 0.43198863,\n",
       "       0.41783026, 0.42809063, 0.43769678, 0.4302791 , 0.41259   ,\n",
       "       0.4011598 , 0.39734104, 0.40664643, 0.3797541 , 0.40003294],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b4162e99-73db-4a92-93e8-0e7e71831178",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.41078353"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ids[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a242c54-911d-4f68-bc51-8fbe88b92795",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
